In [ ]:
from datetime import datetime
import time
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier, HistGradientBoostingClassifier
from sklearn import metrics

from evidently.report import Report
from evidently.metrics import DataDriftTable, DatasetDriftMetric

import shap
import xgboost as xgb
import pickle
import graphviz
from tableone import TableOne, load_dataset
from IPython.display import Latex

from counterfactuals import *
C:\Users\boris\anaconda3\envs\INNO\lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
In [ ]:
shap.initjs()
pd.set_option('display.max_columns', None)
warnings.filterwarnings('ignore')
No description has been provided for this image

XAI-trust overzicht explanations¶

Dit notebook dient als toelichting bij alle explanations die het XAI-trust project gemaakt heeft aan de hand van een model dat sepsis-geassocieerd delirium voorspelt. Het is niet nodig om de paper over het model te lezen, maar kan wellicht wel context bieden: https://doi.org/10.1038/s41598-023-38650-4

In de markdown cellen staat informatie over hoe de explanation afgelezen dient te worden. Sommige stukjes code zijn voorzien van comments die hulp kunnen bieden bij de implementatie op de dashboard.

Table Of Contents¶

  • DATA PREP + MODELLEN
  • GLOBAL EXPLANATIONS
    • TableOne
    • Correlation Matrix
    • Confusion Matrix
    • XGB functionality
      • Feature importance
      • Graphviz
    • SHAP
      • Interactive
      • Bar
      • Beeswarm
    • Data Drift
  • LOCAL EXPLANATIONS
    • Counterfactuals
      • Feature select
      • Genetic
      • KDTree
    • SHAP
      • Force
      • Waterfall
      • Decision

Glossary¶

Categoriaal: Wanneer een feature categoriaal is kunnen we de waarden van die feature niet sorteren op grootte. Een feature 'haarkleur' zou bijvoorbeeld categoriaal zijn.

Correlatie: Ook wel 'lineair verband'. Wanneer twee features een lineaire relatie hebben verwacht je dat een verandering in één variabele een constante verandering in de andere variabele veroorzaakt.

Data drift: ook wel 'Model Drift'. Een model wordt getraind op een set op een bepaald moment, maar de echte wereld veranderd constant; een populatie kan bijvoorbeeld ouder worden, een ecosysteem kan opwarmen, medische techniek kan verbeteren. De mate waarin een model niet meer past op de nieuwe werkelijkheid (dus de mate waarin de test set en de huidige set van elkaar afwijken) noemen we data drift

(Data)set: Grote hoeveelheid gegevenspunten, denk bijvoorbeeld aan een excel sheet. vaak bij het trainen van een machine learning model wordt een set opgedeeld in een test- en een train set

  • traintset: set met data waarop het getraind wordt
  • testset: set die gebruikt wordt om het model te valideren

(Decision)tree: Ook wel 'keuzeboom'. Het is als het ware een soort flowchart (zie 'split') die wordt afgelopen om tot een predictie te komen. XGBoost bestaat uit meerdere decisiontrees die samen tot een voorspelling komen.

Feature: Een kenmerk of eigenschap, aan de hand waarvan wij een voorspelling kunnen maken. Als wij bijvoorbeeld de prijs van een huis willen voorspellen dan kan het woonoppervlak een feature daarvoor zijn.

Kolom: Verticale opeenvolging van informatie-cellen in tabel.

Leave: Als het ware een 'eindpunt' van een decision tree; de keuze waar die op land.

Machine learning: Subcategorie van artificial intelligence. Waar artificial intelligence ook een 'dom' algoritme kan zijn zoals bijvoorbeeld een NPC in een spel, duidt machine learning specifiek op een algoritme wat kan leren.

Mean: Engels voor gemiddelde; het totaal van gegevenspunten gedeeld door het aantaantal gegevenspunten.

Missing: Ontbrekende waarden worden zo aangegeven. In een set staat het ook vaak aangegeven als 'NaN' of 'None'

Model: Een algoritme dat de echte wereld probeert na te bootsen, als het ware probeert te modelleren. In ons geval hebben we een XGBoost model getrains op een hele hoop patiënten met en zonder SAD, zodat als wij een niewe patiënt aandragen het algoritme modelleert of de nieuwe patient SAD heeft.

N: Wordt vaak gebruikt om een discrete grootheid aan te geven. Bijvoorbeeld in een onderzoek waarbij 500 meetpunten genomen zijn kan er 'n = 500' staan.

Ordinaal: Wanneer een feature ordinaal is, kunnen we diens waarden sorteren naar grootte. Bijvoorbeeld een feature 'leeftijd' is ordinaal want we kunnen stellen de leeftijd 45 'groter' is dan de leeftijd 32.

Proxy: Wanneer een feature een proxy is voor een ander, dan is die feature een indirecte indicatie van de ander. Bijvoorbeeld een feature 'burgerlijke staat' zou een proxy kunnen zijn voor een feature 'leeftijd', iemands burgerlijke staat is immers een indicatie van iemands leeftijdscategorie.

Record: Hier mee wordt meestal een rij in een dataset bedoeld.

Rij: Horizontale opeenvolging van informatie-cellen in tabel.

Sample: In dit document wordt hier meestal een subset van de dataset bedoeld.

Sepsis associated delirium: Vaak afgekort als 'SAD', de aandoening die ons model probeert te voorspellen. Soms in dit notebook wordt 'SAD' gebruikt als indicatie dat iemand SAD positief is en 'NON-SAD' als de voorspelling negatief is.

SHAP waarde: Mate waarin een bepaalde feature bijdraagt aan een bepaalde voorspelling. SHAP is kort voor SHapley Additive exPlanations

Split: Binnen een decision tree maak je keuzes aan de hand van een waarde van een feature. Stel je voor we maken een flowchart om het weer te bepalen, je zou dan een keuzemoment hebben waarop je bepaald of het regent of niet; de 'feature' regen kan de 'waarde' wel of niet aannemen. Deze keuze binnen een keuzeboom noemen we een split.

Standaarddeviatie: Vaak afgekort als 'SD', een maat voor de spreiding van gegevenspunten in een dataset rondom het gemiddelde. Een hoge standaarddeviatie indiceert een hoge spreiding van de data.

Verwachtingswaarde: Ook wel gewogen gemiddelde. De waarde die een datapunt gemiddeld aanneemt. Bijvoorbeeld als we een normale dobbelsteen gooien is het gemiddelde van de ogen 3.5 dus de verwachtingswaarde is ook 3.5. Als we een gewogen dobbelsteen gooien waarbij de 6 twee keer zo vaak voorkomt, is het gemiddelde van de ogen nog steeds 3.5, maar de gemiddelde waarde ligt hoger dus de verwachtingswaarde is (1+2+3+4+5+2*6)/7 = 3.86.

X: Invoerwaarde van een functie. Wanneer het gaat om modellen die grote hoeveelheden data tegelijk verwerken kan 'X' dus ook staan voor een hele dataset die in een keer in het model gestopt wordt

XGBoost: Een machine learning model die meerdere decision trees genereert. Aan de hand van de decision trees komt het model tot diens predictie.

y: Uitvoerwaarde van een functie. In ons geval is dat dus de predictie, dus of iemand wel of geen SAD heeft.

DATA + MODELLEN ¶

In [ ]:
data_raw = pd.read_stata('MIMIC-SAD_dta_files/MIMIC-IV.dta')
data_cf = data_raw.drop(['deliriumtime', 'hosp_mort', 'icu28dmort', 'stay_id', 'icustay', 'hospstay', 'sepsistime'], axis=1).dropna()
dummies = pd.get_dummies(data_cf['race'])
data = data_cf.drop('race',axis=1).join(dummies)
dummies = pd.get_dummies(data['first_careunit'])
data = data.drop('first_careunit',axis=1).join(dummies)
xgb_matrix = xgb.DMatrix(data.drop(['sad'], axis=1))

data_cf = data_cf.drop(['sad'], axis=1)

data_table = data_raw.drop('stay_id', axis=1)
data_table['gender'] = data_table['gender'].replace(to_replace = 0.0, value = 'FEMALE')
data_table['gender'] = data_table['gender'].replace(to_replace = 1.0, value = 'MALE')

data_table['sad'] = data_table['sad'].replace(to_replace = 0.0, value = 'NON-SAD')
data_table['sad'] = data_table['sad'].replace(to_replace = 1.0, value = 'SAD')

for i in ['vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke','hosp_mort']:
    data_table[i] = data_table[i].replace(to_replace = 0.0, value = 'FALSE')
    data_table[i] = data_table[i].replace(to_replace = 1.0, value = 'TRUE')

data_raw: data direct uit het .dta bestand van de SAD repo

In [ ]:
data_raw
Out[ ]:
stay_id age weight gender race first_careunit temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm sad aki stroke hosp_mort icustay hospstay deliriumtime sepsistime icu28dmort
0 30000646 44.0 79.000000 0 AISAN CCU 37.000000 100.0 28.0 98.0 107.0 66.0 75.0 8.5 12.9 268.0 12.0 0.900000 102.0 138.0 105.0 3.5 2.2 7.8 3.4 1.3 14.500000 37.400002 25.0 12.0 15.0 0.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 113.0 200.0 NaN 10.0 0.0
1 30001446 56.0 119.300003 0 WHITE MICU 36.720001 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.700000 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.400000 38.400002 15.0 14.0 15.0 0.0 0.0 1.0 0.0 8 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 40.0 147.0 NaN 1.0 0.0
2 30002415 73.0 83.500000 1 WHITE CVICU 36.439999 71.0 16.0 100.0 117.0 67.0 87.0 6.7 10.4 96.0 9.0 0.600000 170.0 136.0 111.0 4.5 3.2 NaN NaN 1.8 20.000000 37.700001 21.0 6.0 15.0 0.0 0.0 1.0 1.0 4 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 25.0 137.0 NaN 3.0 0.0
3 30003226 67.0 93.449997 0 BLACK SICU 37.220001 89.0 17.0 98.0 111.0 63.0 71.0 8.3 7.3 225.0 63.0 18.200001 117.0 135.0 93.0 6.8 1.9 8.6 6.2 NaN NaN NaN 24.0 25.0 15.0 0.0 1.0 0.0 0.0 4 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 47.0 322.0 NaN 40.0 0.0
4 30004242 76.0 77.599998 1 BLACK TSICU 36.720001 59.0 21.0 97.0 107.0 90.0 94.0 9.4 11.0 280.0 10.0 0.500000 123.0 136.0 100.0 3.3 1.5 9.1 3.6 1.0 11.300000 24.900000 24.0 15.0 15.0 0.0 0.0 0.0 0.0 3 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 43.0 182.0 NaN 15.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
14615 39993425 93.0 47.799999 1 WHITE MICU 35.830002 115.0 16.0 99.0 97.0 71.0 80.0 6.5 11.6 119.0 45.0 0.900000 121.0 152.0 119.0 3.6 2.3 8.1 2.8 1.7 18.299999 29.500000 19.0 13.0 8.0 0.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 63.0 84.0 2.0 2.0 0.0
14616 39993476 67.0 93.000000 0 WHITE CVICU 36.439999 81.0 16.0 100.0 112.0 60.0 78.0 13.1 13.6 166.0 12.0 0.700000 113.0 135.0 105.0 4.5 2.3 8.2 2.5 1.2 13.300000 23.400000 24.0 10.0 15.0 1.0 0.0 0.0 1.0 2 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 24.0 100.0 NaN 13.0 0.0
14617 39993968 91.0 57.500000 1 WHITE CCU 35.830002 43.0 17.0 100.0 78.0 39.0 49.0 15.8 14.7 258.0 33.0 1.500000 177.0 132.0 98.0 5.3 2.0 8.8 5.1 1.0 12.400000 26.400000 23.0 16.0 15.0 0.0 0.0 1.0 0.0 4 0.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 142.0 143.0 25.0 3.0 0.0
14618 39996044 59.0 66.400002 0 WHITE MICU 36.389999 105.0 23.0 100.0 107.0 63.0 80.0 3.7 8.0 15.0 20.0 0.500000 161.0 139.0 105.0 4.0 1.9 7.6 4.5 1.3 13.900000 26.100000 26.0 12.0 15.0 1.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 169.0 741.0 15.0 3.0 0.0
14619 39999301 78.0 107.699997 0 BLACK CVICU 36.610001 58.0 15.0 96.0 108.0 62.0 73.0 9.3 12.5 197.0 17.0 1.500000 114.0 142.0 109.0 3.4 2.1 8.8 3.4 1.1 13.400000 26.500000 24.0 12.0 15.0 0.0 0.0 0.0 1.0 2 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 28.0 25.0 NaN 24.0 1.0

14620 rows × 50 columns

data_cf: data voor counterfactual modellen (update: deprecated, counterfactuals debruikt nu data) (update 2: deze wordt nu wel gebruikt voor outlier detection)

  • geen one-hot encoding
  • NaN rijen gedropt
  • ongebruikte features gedropt ('hosp_mort' etc.)
  • target ('sad') gedropt
In [ ]:
data_cf 
Out[ ]:
age weight gender race first_careunit temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm aki stroke
0 44.0 79.000000 0 AISAN CCU 37.000000 100.0 28.0 98.0 107.0 66.0 75.0 8.500000 12.9 268.0 12.0 0.9 102.0 138.0 105.0 3.5 2.2 7.8 3.4 1.3 14.500000 37.400002 25.0 12.0 15.0 0.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 56.0 119.300003 0 WHITE MICU 36.720001 82.0 22.0 90.0 75.0 56.0 61.0 13.000000 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.400000 38.400002 15.0 14.0 15.0 0.0 0.0 1.0 0.0 8 0.0 0.0 0.0 0.0 0.0 1.0 0.0
4 76.0 77.599998 1 BLACK TSICU 36.720001 59.0 21.0 97.0 107.0 90.0 94.0 9.400000 11.0 280.0 10.0 0.5 123.0 136.0 100.0 3.3 1.5 9.1 3.6 1.0 11.300000 24.900000 24.0 15.0 15.0 0.0 0.0 0.0 0.0 3 0.0 0.0 0.0 1.0 0.0 0.0 0.0
5 83.0 72.000000 0 WHITE SICU 36.330002 109.0 16.0 100.0 111.0 63.0 79.0 4.800000 13.3 307.0 62.0 2.8 108.0 136.0 108.0 3.6 2.1 6.4 4.1 1.4 16.200001 26.900000 18.0 14.0 15.0 1.0 0.0 1.0 1.0 3 0.0 0.0 0.0 1.0 0.0 1.0 0.0
6 57.0 77.500000 0 WHITE MICU 38.669998 101.0 23.0 99.0 130.0 84.0 93.0 17.200001 15.1 261.0 25.0 1.0 100.0 138.0 105.0 4.3 2.0 8.5 4.0 1.2 13.500000 33.799999 21.0 16.0 13.0 1.0 0.0 0.0 1.0 3 0.0 0.0 0.0 0.0 0.0 1.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
14615 93.0 47.799999 1 WHITE MICU 35.830002 115.0 16.0 99.0 97.0 71.0 80.0 6.500000 11.6 119.0 45.0 0.9 121.0 152.0 119.0 3.6 2.3 8.1 2.8 1.7 18.299999 29.500000 19.0 13.0 8.0 0.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 0.0 0.0
14616 67.0 93.000000 0 WHITE CVICU 36.439999 81.0 16.0 100.0 112.0 60.0 78.0 13.100000 13.6 166.0 12.0 0.7 113.0 135.0 105.0 4.5 2.3 8.2 2.5 1.2 13.300000 23.400000 24.0 10.0 15.0 1.0 0.0 0.0 1.0 2 0.0 0.0 0.0 1.0 0.0 1.0 0.0
14617 91.0 57.500000 1 WHITE CCU 35.830002 43.0 17.0 100.0 78.0 39.0 49.0 15.800000 14.7 258.0 33.0 1.5 177.0 132.0 98.0 5.3 2.0 8.8 5.1 1.0 12.400000 26.400000 23.0 16.0 15.0 0.0 0.0 1.0 0.0 4 0.0 0.0 0.0 0.0 0.0 1.0 0.0
14618 59.0 66.400002 0 WHITE MICU 36.389999 105.0 23.0 100.0 107.0 63.0 80.0 3.700000 8.0 15.0 20.0 0.5 161.0 139.0 105.0 4.0 1.9 7.6 4.5 1.3 13.900000 26.100000 26.0 12.0 15.0 1.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 0.0 0.0
14619 78.0 107.699997 0 BLACK CVICU 36.610001 58.0 15.0 96.0 108.0 62.0 73.0 9.300000 12.5 197.0 17.0 1.5 114.0 142.0 109.0 3.4 2.1 8.8 3.4 1.1 13.400000 26.500000 24.0 12.0 15.0 0.0 0.0 0.0 1.0 2 0.0 1.0 0.0 0.0 0.0 1.0 0.0

11196 rows × 42 columns

data_table: data voor TableOne en de correlatie matrix: een aantal features zijn aangepast om het aangenamer lezen te maken (bijv 1.0 vervangen voor 'TRUE', als het een boolean kolom is)

  • geen one-hot encoding
  • geen NaN rijen gedropt
  • geen ongebruikte features gedropt, behalve 'stay_id': dit is een index; data-analyse hierop doen zou onzinnig zijn
  • geen target gedropt
In [ ]:
data_table
Out[ ]:
age weight gender race first_careunit temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm sad aki stroke hosp_mort icustay hospstay deliriumtime sepsistime icu28dmort
0 44.0 79.000000 FEMALE AISAN CCU 37.000000 100.0 28.0 98.0 107.0 66.0 75.0 8.5 12.9 268.0 12.0 0.900000 102.0 138.0 105.0 3.5 2.2 7.8 3.4 1.3 14.500000 37.400002 25.0 12.0 15.0 FALSE FALSE TRUE FALSE 3 FALSE FALSE FALSE FALSE FALSE NON-SAD FALSE FALSE TRUE 113.0 200.0 NaN 10.0 0.0
1 56.0 119.300003 FEMALE WHITE MICU 36.720001 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.700000 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.400000 38.400002 15.0 14.0 15.0 FALSE FALSE TRUE FALSE 8 FALSE FALSE FALSE FALSE FALSE NON-SAD TRUE FALSE FALSE 40.0 147.0 NaN 1.0 0.0
2 73.0 83.500000 MALE WHITE CVICU 36.439999 71.0 16.0 100.0 117.0 67.0 87.0 6.7 10.4 96.0 9.0 0.600000 170.0 136.0 111.0 4.5 3.2 NaN NaN 1.8 20.000000 37.700001 21.0 6.0 15.0 FALSE FALSE TRUE TRUE 4 FALSE FALSE FALSE TRUE FALSE NON-SAD TRUE FALSE FALSE 25.0 137.0 NaN 3.0 0.0
3 67.0 93.449997 FEMALE BLACK SICU 37.220001 89.0 17.0 98.0 111.0 63.0 71.0 8.3 7.3 225.0 63.0 18.200001 117.0 135.0 93.0 6.8 1.9 8.6 6.2 NaN NaN NaN 24.0 25.0 15.0 FALSE TRUE FALSE FALSE 4 FALSE FALSE FALSE FALSE FALSE NON-SAD TRUE FALSE FALSE 47.0 322.0 NaN 40.0 0.0
4 76.0 77.599998 MALE BLACK TSICU 36.720001 59.0 21.0 97.0 107.0 90.0 94.0 9.4 11.0 280.0 10.0 0.500000 123.0 136.0 100.0 3.3 1.5 9.1 3.6 1.0 11.300000 24.900000 24.0 15.0 15.0 FALSE FALSE FALSE FALSE 3 FALSE FALSE FALSE TRUE FALSE NON-SAD FALSE FALSE FALSE 43.0 182.0 NaN 15.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
14615 93.0 47.799999 MALE WHITE MICU 35.830002 115.0 16.0 99.0 97.0 71.0 80.0 6.5 11.6 119.0 45.0 0.900000 121.0 152.0 119.0 3.6 2.3 8.1 2.8 1.7 18.299999 29.500000 19.0 13.0 8.0 FALSE FALSE TRUE FALSE 3 FALSE FALSE FALSE FALSE FALSE SAD FALSE FALSE TRUE 63.0 84.0 2.0 2.0 0.0
14616 67.0 93.000000 FEMALE WHITE CVICU 36.439999 81.0 16.0 100.0 112.0 60.0 78.0 13.1 13.6 166.0 12.0 0.700000 113.0 135.0 105.0 4.5 2.3 8.2 2.5 1.2 13.300000 23.400000 24.0 10.0 15.0 TRUE FALSE FALSE TRUE 2 FALSE FALSE FALSE TRUE FALSE NON-SAD TRUE FALSE FALSE 24.0 100.0 NaN 13.0 0.0
14617 91.0 57.500000 MALE WHITE CCU 35.830002 43.0 17.0 100.0 78.0 39.0 49.0 15.8 14.7 258.0 33.0 1.500000 177.0 132.0 98.0 5.3 2.0 8.8 5.1 1.0 12.400000 26.400000 23.0 16.0 15.0 FALSE FALSE TRUE FALSE 4 FALSE FALSE FALSE FALSE FALSE SAD TRUE FALSE FALSE 142.0 143.0 25.0 3.0 0.0
14618 59.0 66.400002 FEMALE WHITE MICU 36.389999 105.0 23.0 100.0 107.0 63.0 80.0 3.7 8.0 15.0 20.0 0.500000 161.0 139.0 105.0 4.0 1.9 7.6 4.5 1.3 13.900000 26.100000 26.0 12.0 15.0 TRUE FALSE TRUE FALSE 3 FALSE FALSE FALSE FALSE FALSE SAD FALSE FALSE FALSE 169.0 741.0 15.0 3.0 0.0
14619 78.0 107.699997 FEMALE BLACK CVICU 36.610001 58.0 15.0 96.0 108.0 62.0 73.0 9.3 12.5 197.0 17.0 1.500000 114.0 142.0 109.0 3.4 2.1 8.8 3.4 1.1 13.400000 26.500000 24.0 12.0 15.0 FALSE FALSE FALSE TRUE 2 FALSE TRUE FALSE FALSE FALSE NON-SAD TRUE FALSE TRUE 28.0 25.0 NaN 24.0 1.0

14620 rows × 49 columns

data: data voor alle andere explanations

  • one-hot encoding
  • NaN rijen gedropt
  • ongebruikte features gedropt
  • geen target gedropt
In [ ]:
data # data voor alle andere explanations: wel one-hot, wel NaN gedropt, wel ongebruikte features gedropt, geen target gedropt
Out[ ]:
age weight gender temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm sad aki stroke AISAN BLACK HISPANIC OTHER WHITE unknown CCU CVICU MICU MICU/SICU NICU SICU TSICU
0 44.0 79.000000 0 37.000000 100.0 28.0 98.0 107.0 66.0 75.0 8.500000 12.9 268.0 12.0 0.9 102.0 138.0 105.0 3.5 2.2 7.8 3.4 1.3 14.500000 37.400002 25.0 12.0 15.0 0.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1 0 0 0 0 0 1 0 0 0 0 0 0
1 56.0 119.300003 0 36.720001 82.0 22.0 90.0 75.0 56.0 61.0 13.000000 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.400000 38.400002 15.0 14.0 15.0 0.0 0.0 1.0 0.0 8 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0
4 76.0 77.599998 1 36.720001 59.0 21.0 97.0 107.0 90.0 94.0 9.400000 11.0 280.0 10.0 0.5 123.0 136.0 100.0 3.3 1.5 9.1 3.6 1.0 11.300000 24.900000 24.0 15.0 15.0 0.0 0.0 0.0 0.0 3 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0 1 0 0 0 0 0 0 0 0 0 0 1
5 83.0 72.000000 0 36.330002 109.0 16.0 100.0 111.0 63.0 79.0 4.800000 13.3 307.0 62.0 2.8 108.0 136.0 108.0 3.6 2.1 6.4 4.1 1.4 16.200001 26.900000 18.0 14.0 15.0 1.0 0.0 1.0 1.0 3 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 0 0 0 1 0
6 57.0 77.500000 0 38.669998 101.0 23.0 99.0 130.0 84.0 93.0 17.200001 15.1 261.0 25.0 1.0 100.0 138.0 105.0 4.3 2.0 8.5 4.0 1.2 13.500000 33.799999 21.0 16.0 13.0 1.0 0.0 0.0 1.0 3 0.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
14615 93.0 47.799999 1 35.830002 115.0 16.0 99.0 97.0 71.0 80.0 6.500000 11.6 119.0 45.0 0.9 121.0 152.0 119.0 3.6 2.3 8.1 2.8 1.7 18.299999 29.500000 19.0 13.0 8.0 0.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0
14616 67.0 93.000000 0 36.439999 81.0 16.0 100.0 112.0 60.0 78.0 13.100000 13.6 166.0 12.0 0.7 113.0 135.0 105.0 4.5 2.3 8.2 2.5 1.2 13.300000 23.400000 24.0 10.0 15.0 1.0 0.0 0.0 1.0 2 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 1 0 0 0 0 0
14617 91.0 57.500000 1 35.830002 43.0 17.0 100.0 78.0 39.0 49.0 15.800000 14.7 258.0 33.0 1.5 177.0 132.0 98.0 5.3 2.0 8.8 5.1 1.0 12.400000 26.400000 23.0 16.0 15.0 0.0 0.0 1.0 0.0 4 0.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0 0 0 0 1 0 1 0 0 0 0 0 0
14618 59.0 66.400002 0 36.389999 105.0 23.0 100.0 107.0 63.0 80.0 3.700000 8.0 15.0 20.0 0.5 161.0 139.0 105.0 4.0 1.9 7.6 4.5 1.3 13.900000 26.100000 26.0 12.0 15.0 1.0 0.0 1.0 0.0 3 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0
14619 78.0 107.699997 0 36.610001 58.0 15.0 96.0 108.0 62.0 73.0 9.300000 12.5 197.0 17.0 1.5 114.0 142.0 109.0 3.4 2.1 8.8 3.4 1.1 13.400000 26.500000 24.0 12.0 15.0 0.0 0.0 0.0 1.0 2 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0 1 0 0 0 0 0 1 0 0 0 0 0

11196 rows × 54 columns

In [ ]:
model = pickle.load(open("xgb.pkl", "rb"))

cf_random = pickle.load(open("cf_random.pkl", "rb"))
cf_genetic = pickle.load(open("cf_genetic.pkl", "rb"))
cf_kdtree = pickle.load(open("cf_kdtree.pkl", "rb"))

explainer = shap.TreeExplainer(model)
explainer_waterfall = shap.Explainer(model, data)  # het lijkt er op dat we beter de waterfall plot niet kunnen gebruiken, in dat geval zou deze dus ook niet nodig zijn
shap_values = explainer.shap_values(data.loc[:, ~data.columns.isin(["sad"])])
shap_waterfall = explainer_waterfall(data)
 94%|=================== | 10569/11196 [00:12<00:00]       

GLOBAL EXPLANATIONS ¶

Sommige explanations in deze categorie (confusion matrix, Table1) kunnen meerdere keren in het dashboard: een keer voor de oorspronkelijke (test)set, en een keer voor de set die op dat moment aan het dashboard gekoppeld is.

TableOne ¶

Gebaseerd op de Table1 package van de programmeer taal R. Deze tabel geeft een algemeen overzicht van de beschikbare data.

Hoe je dit afleest: In de linker kolom staat om welke variabele het gaat. Hierbinnen heb je twee soorten variabelen:

  • ordinale variabelen: De eerste waarde in de drie laatste kolommen geeft het gemiddelde aan binnen de categorie van die kolom. De tweede waarde, die tussen haakjes, geeft de standaarddeviatie (SD) weer. Bijvoorbeeld 'age' heeft over de gehele dataset een gemiddelde van 66.9 en een SD van 15.9. bij mensen die SAD hebben is de gemiddelde 'age' 67.3 met een SD van 16.1.
  • categoriale variabelen: De tweede kolom specificeert welke waarde deze variabele aanneemt. De linker waarde geeft aan hoe vaak deze waarde voorkomt binnen de respectievelijke kolom, de rechter waarde laat zien welk percentage dit is van de volledige set. Bijvoorbeeld 'gender' neemt binnen de hele set 8518 de waarde 'FEMALE' aan, dit is 58.3% van de hele set. Hier tegenover staat dat er 6102 'MALE' zijn, dus 41.7%. Binnen de SAD patienten is 57.6% 'MALE' en 42.4% 'FEMALE'

De kolom 'missing' geeft aan hoe vaak de variabele geen waarde aanneemt binnen de set. De rij 'n' is niet een variabele, dit gaat over het totaal aantal meetpunten binnen de categorie van de kolom.

De rijen 'deliriumtime', 'hosp_mort', 'icu28dmort', 'icustay', 'hospstay' en 'sepsistime' zijn geen features in het model, deze informatie is immers niet aanwezig op het moment van voorspellen; ze staan in de tabel omdat dit wel handige informatie is.

In [ ]:
categorical = ['gender', 'race', 'first_careunit', 'vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'sad', 'aki', 'stroke','hosp_mort']
groupby = ['sad']
table1 = TableOne(data_table, categorical=categorical, groupby=groupby, pval=False)
In [ ]:
print(table1.tabulate(tablefmt = "fancy_grid")) # je kan "fancy_grid" vervangen voor "html" als dit makkelijker is voor de dashboard
╒═════════════════════════╤═══════════╤═══════════╤═══════════════╤═══════════════╤═══════════════╕
│                         │           │ Missing   │ Overall       │ NON-SAD       │ SAD           │
╞═════════════════════════╪═══════════╪═══════════╪═══════════════╪═══════════════╪═══════════════╡
│ n                       │           │           │ 14620         │ 9230          │ 5390          │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ age, mean (SD)          │           │ 0         │ 66.9 (15.9)   │ 66.7 (15.8)   │ 67.3 (16.1)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ weight, mean (SD)       │           │ 160       │ 83.1 (23.6)   │ 83.1 (23.0)   │ 83.2 (24.6)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ gender, n (%)           │ FEMALE    │ 0         │ 8518 (58.3)   │ 5416 (58.7)   │ 3102 (57.6)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ MALE      │           │ 6102 (41.7)   │ 3814 (41.3)   │ 2288 (42.4)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ race, n (%)             │ AISAN     │ 0         │ 426 (2.9)     │ 311 (3.4)     │ 115 (2.1)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ BLACK     │           │ 1266 (8.7)    │ 766 (8.3)     │ 500 (9.3)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ HISPANIC  │           │ 557 (3.8)     │ 360 (3.9)     │ 197 (3.7)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ OTHER     │           │ 642 (4.4)     │ 417 (4.5)     │ 225 (4.2)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ WHITE     │           │ 9723 (66.5)   │ 6372 (69.0)   │ 3351 (62.2)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ unknown   │           │ 2006 (13.7)   │ 1004 (10.9)   │ 1002 (18.6)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ first_careunit, n (%)   │ CCU       │ 0         │ 1366 (9.3)    │ 881 (9.5)     │ 485 (9.0)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ CVICU     │           │ 3461 (23.7)   │ 2772 (30.0)   │ 689 (12.8)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ MICU      │           │ 3078 (21.1)   │ 1601 (17.3)   │ 1477 (27.4)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ MICU/SICU │           │ 2706 (18.5)   │ 1780 (19.3)   │ 926 (17.2)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ NICU      │           │ 534 (3.7)     │ 234 (2.5)     │ 300 (5.6)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ SICU      │           │ 1887 (12.9)   │ 1112 (12.0)   │ 775 (14.4)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TSICU     │           │ 1588 (10.9)   │ 850 (9.2)     │ 738 (13.7)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ temperature, mean (SD)  │           │ 48        │ 36.7 (0.8)    │ 36.7 (0.8)    │ 36.8 (0.9)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ heart_rate, mean (SD)   │           │ 1         │ 89.7 (20.3)   │ 88.2 (19.6)   │ 92.3 (21.1)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ resp_rate, mean (SD)    │           │ 24        │ 19.6 (6.0)    │ 19.1 (5.9)    │ 20.6 (6.1)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ spo2, mean (SD)         │           │ 5         │ 97.1 (4.0)    │ 97.3 (3.7)    │ 96.7 (4.3)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ sbp, mean (SD)          │           │ 6         │ 120.3 (23.9)  │ 119.6 (23.1)  │ 121.4 (25.1)  │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ dbp, mean (SD)          │           │ 15        │ 66.5 (17.7)   │ 65.7 (16.8)   │ 67.9 (19.0)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ mbp, mean (SD)          │           │ 14        │ 81.5 (17.8)   │ 81.0 (17.0)   │ 82.5 (19.0)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ wbc, mean (SD)          │           │ 115       │ 13.1 (8.1)    │ 12.8 (7.8)    │ 13.7 (8.4)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ hemoglobin, mean (SD)   │           │ 96        │ 10.3 (2.2)    │ 10.3 (2.1)    │ 10.5 (2.3)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ platelet, mean (SD)     │           │ 102       │ 191.5 (106.0) │ 190.7 (105.3) │ 192.9 (107.1) │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ bun, mean (SD)          │           │ 60        │ 28.2 (22.9)   │ 26.1 (21.0)   │ 31.7 (25.6)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ cr, mean (SD)           │           │ 56        │ 1.5 (1.5)     │ 1.4 (1.5)     │ 1.6 (1.6)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ glu, mean (SD)          │           │ 66        │ 150.2 (74.4)  │ 144.9 (66.5)  │ 159.3 (85.5)  │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ Na, mean (SD)           │           │ 50        │ 137.4 (5.5)   │ 136.9 (5.0)   │ 138.2 (6.2)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ Cl, mean (SD)           │           │ 51        │ 103.8 (6.7)   │ 103.8 (6.3)   │ 103.9 (7.3)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ K, mean (SD)            │           │ 59        │ 4.3 (0.8)     │ 4.3 (0.8)     │ 4.3 (0.9)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ Mg, mean (SD)           │           │ 608       │ 2.0 (0.5)     │ 2.0 (0.5)     │ 2.0 (0.5)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ Ca, mean (SD)           │           │ 1399      │ 8.2 (0.9)     │ 8.2 (0.8)     │ 8.2 (0.9)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ P, mean (SD)            │           │ 1341      │ 3.8 (1.5)     │ 3.7 (1.3)     │ 4.0 (1.7)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ inr, mean (SD)          │           │ 1609      │ 1.5 (0.8)     │ 1.5 (0.7)     │ 1.6 (0.8)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ pt, mean (SD)           │           │ 1578      │ 17.0 (9.8)    │ 16.7 (8.8)    │ 17.5 (11.3)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ ptt, mean (SD)          │           │ 1658      │ 37.8 (22.4)   │ 37.2 (21.4)   │ 39.0 (24.0)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ bicarbonate, mean (SD)  │           │ 57        │ 22.2 (4.6)    │ 22.5 (4.3)    │ 21.7 (5.1)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ aniongap, mean (SD)     │           │ 64        │ 15.0 (4.6)    │ 14.4 (4.2)    │ 16.0 (4.9)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ gcs, mean (SD)          │           │ 2         │ 14.2 (2.4)    │ 14.3 (2.4)    │ 14.1 (2.3)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ vent, n (%)             │ FALSE     │ 0         │ 8023 (54.9)   │ 5974 (64.7)   │ 2049 (38.0)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 6597 (45.1)   │ 3256 (35.3)   │ 3341 (62.0)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ crrt, n (%)             │ FALSE     │ 0         │ 14364 (98.2)  │ 9152 (99.2)   │ 5212 (96.7)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 256 (1.8)     │ 78 (0.8)      │ 178 (3.3)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ vaso, n (%)             │ FALSE     │ 0         │ 7482 (51.2)   │ 4992 (54.1)   │ 2490 (46.2)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 7138 (48.8)   │ 4238 (45.9)   │ 2900 (53.8)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ seda, n (%)             │ FALSE     │ 0         │ 7962 (54.5)   │ 4968 (53.8)   │ 2994 (55.5)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 6658 (45.5)   │ 4262 (46.2)   │ 2396 (44.5)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ sofa_score, mean (SD)   │           │ 0         │ 3.6 (1.9)     │ 3.4 (1.7)     │ 3.9 (2.2)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ ami, n (%)              │ FALSE     │ 0         │ 12976 (88.8)  │ 8293 (89.8)   │ 4683 (86.9)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 1644 (11.2)   │ 937 (10.2)    │ 707 (13.1)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ ckd, n (%)              │ FALSE     │ 0         │ 11680 (79.9)  │ 7417 (80.4)   │ 4263 (79.1)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 2940 (20.1)   │ 1813 (19.6)   │ 1127 (20.9)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ copd, n (%)             │ FALSE     │ 0         │ 14088 (96.4)  │ 8944 (96.9)   │ 5144 (95.4)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 532 (3.6)     │ 286 (3.1)     │ 246 (4.6)     │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ hyperte, n (%)          │ FALSE     │ 0         │ 8322 (56.9)   │ 5158 (55.9)   │ 3164 (58.7)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 6298 (43.1)   │ 4072 (44.1)   │ 2226 (41.3)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ dm, n (%)               │ FALSE     │ 0         │ 11962 (81.8)  │ 7477 (81.0)   │ 4485 (83.2)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 2658 (18.2)   │ 1753 (19.0)   │ 905 (16.8)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ sad, n (%)              │ NON-SAD   │ 0         │ 9230 (63.1)   │ 9230 (100.0)  │               │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ SAD       │           │ 5390 (36.9)   │               │ 5390 (100.0)  │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ aki, n (%)              │ FALSE     │ 0         │ 6462 (44.2)   │ 4541 (49.2)   │ 1921 (35.6)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 8158 (55.8)   │ 4689 (50.8)   │ 3469 (64.4)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ stroke, n (%)           │ FALSE     │ 0         │ 13479 (92.2)  │ 8777 (95.1)   │ 4702 (87.2)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 1141 (7.8)    │ 453 (4.9)     │ 688 (12.8)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ hosp_mort, n (%)        │ FALSE     │ 0         │ 12768 (87.3)  │ 8527 (92.4)   │ 4241 (78.7)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│                         │ TRUE      │           │ 1852 (12.7)   │ 703 (7.6)     │ 1149 (21.3)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ icustay, mean (SD)      │           │ 0         │ 125.7 (147.7) │ 83.8 (92.8)   │ 197.4 (190.4) │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ hospstay, mean (SD)     │           │ 0         │ 311.0 (307.2) │ 256.1 (247.8) │ 404.9 (369.9) │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ deliriumtime, mean (SD) │           │ 9230      │ 44.9 (63.5)   │ nan (nan)     │ 44.9 (63.5)   │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ sepsistime, mean (SD)   │           │ 0         │ 8.2 (20.6)    │ 9.1 (23.0)    │ 6.7 (15.5)    │
├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤
│ icu28dmort, mean (SD)   │           │ 0         │ 0.1 (0.3)     │ 0.0 (0.2)     │ 0.1 (0.3)     │
╘═════════════════════════╧═══════════╧═══════════╧═══════════════╧═══════════════╧═══════════════╛

Correlation Matrix ¶

Deze matrix geeft de mate van lineaire interactie tussen variabelen weer. Als de correlatie tussen twee variabelen te hoog is kan je je dus afvragen of de een een proxy is voor de ander.

In [ ]:
corr = data_table.corr(method='pearson') #method kan aangepast worden naar 'kendall' of 'spearman', zou leuk zijn als dit interactief kan op de dashboard
mask = np.triu(np.ones_like(corr, dtype=bool))
f, ax = plt.subplots(figsize=(20, 16))
cmap = sns.color_palette("viridis_r", as_cmap=True)
sns.heatmap(
    corr, 
    mask=mask, 
    cmap=cmap, 
    vmax=.3, 
    center=0,
    square=True, 
    linewidths=.5, 
    cbar_kws={"shrink": .5}
)
Out[ ]:
<Axes: >
No description has been provided for this image

Confusion Matrix ¶

Deze matrix geeft aan hoe vaak het model een correcte/incorrecte voorspelling doet. Bijv. in dit geval kan je aflezen dat in de 3,845 gevallen van SAD, het model 1,065 keer een incorrecte voorspelling en 2,780 keer een correcte voorspelling gedaan heeft.

Als er een matrix van de train set en de huidige set te zien is kan het verschil geïnterpreteerd worden als een mate van data drift.

In [ ]:
xgb_matrix_full = xgb.DMatrix(data.loc[:, ~data.columns.isin(["sad"])], label=data["sad"])
In [ ]:
# het idee is dat deze er twee keer in staat: een keer met de oorspronkelijke testset en een keer met de set die op dat moment aan de app gekoppeld zit

xgb_pred_prob = model.predict(xgb_matrix_full) # dit is voor de volledige set. vervang `xgb_matrix_full` als je de performance van het model wrt een andere set wilt.
xgb_pred = np.where(xgb_pred_prob > 0.5, 1, 0)
xgb_pred_factor = pd.factorize(xgb_pred)[0]
test_sad_factor = pd.factorize(data["sad"])[0]

confusion_matrix = metrics.confusion_matrix(xgb_pred_factor, test_sad_factor)
print(confusion_matrix)
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, display_labels = [False, True])

cm_display.plot()
plt.show() 
[[5656 1695]
 [1065 2780]]
No description has been provided for this image

XGB Functionality ¶

Feature importance ¶

De volgende grafieken duiden aan in hoe verre bepaalde features invloed hebben op het model, op basis van:

  • gain: de gemiddelde informatiewinst (relatieve entropie) van de splits waarin de respectievelijke feature voorkomt. Een 'split' kan gezien worden als een 'keuze' binnen de keuzeboom.
  • weight: het totaal aantal splits waarin de respectievelijke feature voorkomt in alle trees.
  • cover: gemiddelde van de 'coverage' van splits waarin de respectievelijke feature voorkomt, waarbij coverage is gedefinieerd als het aantal voorspellingen die beinvloed worden door de die split.
In [ ]:
xgb.plot_importance(model, importance_type='gain', max_num_features=30)
plt.title('feature importance: gain')
plt.show()
xgb.plot_importance(model, max_num_features=30)
plt.title('feature importance: weight')
plt.show()
xgb.plot_importance(model, importance_type='cover', max_num_features=30)
plt.title('feature importance: cover')
plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Graphviz tree visualiser ¶

Een graaf visualisatie van één tree in het XGBoost model; het model bestaat uit een groot aantal van dit soort keuzebomen. In het voorbeeld staat een visualisatie van de meest representatieve boom, maar andere trees kunnen ook worden weergegeven.

In [ ]:
xgb.to_graphviz(model, num_trees=model.best_iteration) # je kan ook andere trees visualiseren door `model.best_iteration` te vervangen voor een getal (int), zou top zijn als dit interactief kan op het dashboard
Out[ ]:
No description has been provided for this image

SHAP ¶

Uit de SHAP documentatie:

SHAP (SHapley Additive exPlanations) is a game theoretic approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions.

Kort gezegd: bij elke voorspelling kunnen we een Shapley-waarde voor iedere feature bepalen, deze waarde is een indicatie voor de mate waarin de feature heeft bijgedragen aan deze individuele voorspelling.

wat kan ik met shap waarde

Interactive Plot ¶

Deze plot laat een sample uit de data (in het voorbeeld n = 500) zien over de twee gekozen assen. De plot lijkt aanvankelijk enigszins intimiderend, maar met wat intuïtie komt men er wel uit.

De dropdown aan de linkerkant van de grafiek bepaald de y-as, er kan gekozen worden voor:

  • f(x): de cumulatieve SHAP waarde
  • feature effects: de SHAP waarde van die specifieke feature

De dropdown boven de grafiek bepaald het groeperen en sorteren van de samples over de x-as. De eerste drie opties (similarity, output value, original) groeperen niet maar sorteren op de aangegeven volgorde. De feature specifieke opties groeperen de samples op elke unieke waarde van deze feature, en sorteren vervolgens deze groepen op volgorde van diens waarde in deze feature.

Rood in in de grafiek duidt op een positieve bijdrage aan de uitkomst en blauw een negatieve. Er kan ook op een plotpunt geklikt worden, dan geeft het de index van de betreffende patient, of de index van één van de patiënten binnen een groep wanneer voor de x-as een feature geselecteerd is.

Enkele voorbeelden:

  • Als we de x-as op 'sofa_score' zetten en de y-as op 'f(x)' dan zien we dat, binnen deze sample, de kans het grootst is op SAD bij een sofa score van 10. Wanneer we bij de '10' hoveren met de muis zien we dat de groep met deze score 3 groot is. Ook zien we dan de gemiddelden van een aantal features binnen deze groep, en dat 'vent' het meest positief bijdraagt en 'gcs' het meest negatief.
  • Als we de x-as op 'sample order by output value' en de y-as op 'vent effects' dan zien we dat patiënten met een een lagere outputwaarde (dus mensen die geen SAD hebben) minder vaak aan de beademing zitten.
In [ ]:
# kan zeker in het data-analysten scherm
shap.force_plot(
    explainer.expected_value, 
    shap_values[:500, :], 
    data.loc[:, ~data.columns.isin(["sad"])].iloc[:500, :],
    plot_cmap=["#FDE725", "#440154"]
)
# `500` kan aangepast worden voor een groter/kleiner sample, let wel dat een te grote sample veel lag meebrengt
# er kan ook een lijst met indices worden meegegeven als we een specifiek groepje patiënten willen vergelijken
Out[ ]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Summary plot - Bar¶

Deze plot is vergelijkbaar met de feature importance plots, maar dan dus op basis van de gemiddelde SHAP waarden van de feature

In [ ]:
shap.summary_plot(shap_values, data.loc[:, ~data.columns.isin(["sad"])], plot_type="bar")
No description has been provided for this image

Summary plot - Beeswarm ¶

De 'beeswarms' (of, sina plots) geven de dichtheid van datapunten rond bepaalde SHAP waarden van de respectievelijke feature aan. De kleur geeft de waarde van de feature zelf aan. Je kan hieruit aflezen wat voor soort waarden de features aan moeten nemen voor een grote impact op het model, en ook hoe vaak dit voor komt. Bijvoorbeeld: bij de feature 'Na' zien we bij hoge SHAP een dunne, fel rode lijn. Dit betekent dat bij een hoog 'Na' gehalte we een grote positieve invloed op de uitkomst van het model kunnen verwachten, maar dat dit niet vaak voorkomt. Aan de andere kant van de 'Na' beeswarm, net iets onder een SHAP-waarde van 0, is de sina vrij dik en blauw. Dit laat zien dat een relatief lage 'Na' waarde een licht negatieve invloed heeft op de uitkomst, maar dat dit erg vaak voor komt.

Je kan de bar summary plot zien als een compacte versie van deze plot: als we de absolute van de SHAP waarden in deze plot nemen (dus de alles aan de linker kant van de 0 lijn als het ware spiegelen), en vervolgens het gemiddelde van de beeswarms nemen, dan krijg je de bovenstaande bar plot.

In [ ]:
shap.summary_plot(shap_values, data.loc[:, ~data.columns.isin(["sad"])], cmap=plt.get_cmap("viridis"))
No description has been provided for this image

Partial dependence plots ¶

Om in te zoomen op de effectiviteit kunnen we kijken naar partial dependence plots. In de onderstaande plot zien we op de x as de waarde van de betreffende feature, en op de y as de verwachtingswaarde (gewogen gemiddelde) van de voorspelling bij die x-waarde. De horizontale stippellijn is de verwachtinswaarde van de voorspellingen op de volledige dataset, en de verticale die van de betreffende feature. de lichtgekleurde staven op de achtergrond zijn een histogram van de dataset op basis van de feature.

In het voorbeeld kijken we naar de partial dependence van 'gcs'. We zien dat de verwachtingswaarde het hoogst is bij een 'gcs' tussen de 10 en 12, dat wil zeggen dat iemand het meeste kans op SAD heeft wanneer 'gcs' tussen de 10 en 12 zit.

Deze plot is verder niet gebaseerd op SHAP, maar als we de relatie tussen 'gcs' en de bijbehorende SHAP-waarden plotten, zien we hoe goed die overeenkomen.

In [ ]:
shap.partial_dependence_plot(
    "gcs",  # 'gcs' kan worden vervangen voor een andere feature om die te bekijken
    lambda X: model.predict(xgb.DMatrix(X.drop(['sad'], axis=1))),
    data,
    ice=False,
    model_expected_value=True,
    feature_expected_value=True,
)
No description has been provided for this image

Zoals al gezegd is deze plot qua functie vergelijkbaar met de vorige.

Deze plot zegt echter ook iets over de relatie van de betreffende feature met andere features. In dit voorbeeld zien we dat wanneer de SHAP waarde voor 'gcs' relatief laag is, de patient vaak aan de beademing zit, echter bij de allerhoogste waarden die 'gcs' aanneemt dit effect omgekeerd is.

In [ ]:
shap.dependence_plot(
    'gcs', # 'gcs' kan worden vervangen voor een andere feature om die te bekijken
    shap_values, 
    data.loc[:, ~data.columns.isin(["sad"])],
    cmap=plt.get_cmap("viridis")
    # interaction_index=None
) # `interaction_index=None` kan in-gecomment worden om de kleur weg te halen. 
  # Ook kan `None` vervangen worden door een feature naam (of index) om handmatig de feature waarmee we vergelijken te kiezen. 
  # het zou wellicht leuk zijn als dit interactief kan op het dashboard
No description has been provided for this image

Nog een voorbeeld van dezelfde plot op een andere feature: we zien hier dat 'aniongap' negatief correleert met 'bicarbonate'. Dit kan een indicatie zijn dat de een een proxy is voor de ander.

In [ ]:
shap.dependence_plot(
    'aniongap', # 'gcs' kan worden vervangen voor een andere feature om die te bekijken
    shap_values, 
    data.loc[:, ~data.columns.isin(["sad"])],
    interaction_index='bicarbonate',
    cmap=plt.get_cmap("viridis")
)
No description has been provided for this image

Data Drift ¶

Maakt duidelijk in hoeverre het model afwijkt van de huidige werkelijkheid

In [ ]:
data_drift_dataset_report = Report(metrics=[
    DatasetDriftMetric(),
    DataDriftTable(),    
])

# voor het voorbeeld heb ik gwn de data in tweeën gesplitst, in de code moet 'reference_data' de data zijn waar het model op is getraind en 'current_data' de set die op dat moment aan de applicatie hangt
data_drift_dataset_report.run(reference_data=data[:int(data.shape[0]/2)], current_data=data[int(data.shape[0]/2):])
data_drift_dataset_report
Out[ ]:

LOCAL EXPLANATIONS ¶

Certainty score ¶

In [ ]:
# vervang '1' voor de index van de respectievelijke patient
score = model.predict(xgb.DMatrix(data.loc[[1], ~data.columns.isin(["sad"])], label=data["sad"]))[0]
if score < 0.5: score = 1-score
print("Certainty score: " + str(round(score*100, 2)) + "%")
Certainty score: 63.03%

Outlier detection¶

In [ ]:
data_cf.shape
Out[ ]:
(11196, 42)
In [ ]:
record = data_cf.iloc[1]
is_categorical = ['race', 'first_careunit', 'vent', 'ckd', 'crrt', 'copd', 'gender', 'vaso', 'hyperte', 'seda', 'dm', 'aki', 'ami', 'stroke']

fig, axs = plt.subplots(6, 7, figsize=(12, 6))
i = 0
j = 0
score = 0

for c in data_cf.columns:
    axs[i,j].hist(data_cf[c], bins=20)
    title = c
    if c not in is_categorical:
        title = title + "\n σ = " + str(round(data_cf[c].std(), 2)) + "\n |x-μ| = " + str(round(abs(record[c] - data_cf[c].mean()), 2))
        score += (abs(record[c] - data_cf[c].mean()) / data_cf[c].std())
        
    axs[i,j].set_title(title, fontsize=8)
    axs[i,j].get_xaxis().set_visible(False)
    axs[i,j].get_yaxis().set_visible(False)
    axs[i,j].axvline(record[c], color='r', linestyle='dashed', linewidth=1)
    i+=1
    if (i%6)==0:
        i = 0
        j += 1
score /= data_cf.shape[1]
fig.suptitle("Outlier score: " + str(score))
fig.tight_layout()
plt.show()
No description has been provided for this image

Counterfactuals ¶

Het globale idee van een counterfactual is om te redeneren over wat er zou zijn gebeurd als bepaalde omstandigheden anders waren geweest. Er wordt een hypothetisch scenario gecreëert waarin één of meerdere inputs worden veranderd waarvoor de voorspelling anders zou zijn.

Feature select ¶

Counterfactuals worden gegenereerd waarbij in één of meerdere vooraf bepaalde features gevariëerd wordt.

In het voorbeeld wordt er gevarieerd in de features age, weight, temperature en gcs. In de eerste cell zien we de oorspronkelijke record, daarna volgen 5 counterfactuals die hierop zijn gebaseerd.

Aan de tabel met counterfactuals zijn een aantal kolommen toegevoegd:

  • 'reg': uitkomst van het model, uitgaande van deze counterfactual record
  • 'pred': True/False uitkomst gebaseerd op 'reg'
  • 'fitness': fitness score aan de hand waarvan het algoritme bepaalt hoe dicht deze counterfactual bij het origineel zit. In de huidige versie is het de som van de euclidische afstanden tussen de features.
In [ ]:
data.drop(['sad'], axis=1).iloc[[1]]
Out[ ]:
age weight gender temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm aki stroke AISAN BLACK HISPANIC OTHER WHITE unknown CCU CVICU MICU MICU/SICU NICU SICU TSICU
1 56.0 119.300003 0 36.720001 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 15.0 0.0 0.0 1.0 0.0 8 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0
In [ ]:
dummy_groupings = {'race':['AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown'], 'first_careunit': ['CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']}
use_feats = ['age', 'weight', 'temperature', 'gcs']
# `limit` bepaalt hoe vaak het genetisch algoritme doorlopen wordt. een hoge limit geeft counterfactuals die veel variëren maar minder op de oorspronkelijke patiënt lijken, een lage limit geeft cfs die minder variëren maar meer op de patiënt lijken.
# gezien het algoritme eerder convergeert wanneer er in minder features gevariëerd wordt, zal `limit` lager gezet moeten worden bij een kortere `use_feats`
# wellicht is het een idee om `limit` interactief te maken op het dashboard, zodat de gebruiker die zelf in kan stellen voor een gewenst resultaat.
cf_g = GeneticCounterfactual(data.drop(['sad'], axis=1), model, dummy_groupings, use_feats=use_feats, limit=1, population_size=data.drop(['sad'], axis=1).shape[0])
cf_g.generate(data.drop(['sad'], axis=1).iloc[[1]], 5)
Out[ ]:
age weight gender temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm aki stroke AISAN BLACK HISPANIC OTHER WHITE unknown CCU CVICU MICU MICU/SICU NICU SICU TSICU reg pred fitness
0 56.0 119.449997 0.0 37.110001 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 14.0 0.0 0.0 1.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1 False False False False False 1 False False False False False False 0.504692 True 1.083789
1 56.0 119.449997 0.0 37.389999 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 14.0 0.0 0.0 1.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1 False False False False False 1 False False False False False False 0.564591 True 1.213011
2 55.0 119.449997 0.0 37.389999 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 14.0 0.0 0.0 1.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1 False False False False False 1 False False False False False False 0.557112 True 1.572067
3 55.0 119.449997 0.0 38.439999 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 14.0 0.0 0.0 1.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1 False False False False False 1 False False False False False False 0.557112 True 2.231791
4 60.0 121.500000 0.0 37.169998 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 14.0 0.0 0.0 1.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1 False False False False False 1 False False False False False False 0.509901 True 4.694942

Genetic ¶

Deze methode genereert op basis van een genetisch algoritme een volledig nieuwe patiënt waarvoor het model een andere voorspelling zou doen. De fictieve patiënt wordt zo gegenereerd dat die zo min mogelijk van de daadwerkelijke patiënt afwijkt.

In [ ]:
data.drop(['sad'], axis=1).iloc[[1]]
Out[ ]:
age weight gender temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm aki stroke AISAN BLACK HISPANIC OTHER WHITE unknown CCU CVICU MICU MICU/SICU NICU SICU TSICU
1 56.0 119.300003 0 36.720001 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 15.0 0.0 0.0 1.0 0.0 8 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0
In [ ]:
cf_genetic = GeneticCounterfactual(data.drop(['sad'], axis=1), model, dummy_groupings, limit=10, population_size=data.shape[0])
cf_genetic.generate(data.drop(['sad'], axis=1).iloc[[1]], 5)
Out[ ]:
age weight gender temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm aki stroke AISAN BLACK HISPANIC OTHER WHITE unknown CCU CVICU MICU MICU/SICU NICU SICU TSICU reg pred fitness
0 72.0 111.500000 0.0 36.560001 81.0 21.0 94.0 72.0 55.0 67.0 19.6 10.3 40.0 72.0 1.1 81.0 129.0 106.0 3.9 1.8 7.9 4.0 1.4 20.299999 37.200001 15.0 15.0 14.0 1.0 0.0 0.0 1.0 7.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 False False False False 0 False False False 1 False False False False 0.685672 True 22.003765
1 50.0 126.699997 1.0 37.000000 84.0 16.0 97.0 85.0 55.0 61.0 12.1 7.6 36.0 73.0 1.3 84.0 130.0 102.0 3.3 1.4 9.7 3.2 1.1 13.600000 32.400002 20.0 17.0 15.0 1.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 False False False False 1 False False False 0 False False False False 0.703455 True 22.097925
2 50.0 126.699997 1.0 37.000000 83.0 16.0 97.0 85.0 55.0 61.0 12.1 7.6 36.0 74.0 1.3 84.0 130.0 94.0 3.3 2.2 8.4 3.2 1.1 17.900000 32.400002 20.0 17.0 15.0 1.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 1.0 0.0 1.0 1.0 False False False False 1 False False False 0 False False False False 0.709303 True 22.542367
3 72.0 124.949997 0.0 36.560001 90.0 22.0 97.0 72.0 55.0 61.0 12.1 9.6 31.0 72.0 2.3 81.0 123.0 101.0 3.4 1.8 8.4 0.9 1.8 20.299999 37.200001 20.0 14.0 15.0 1.0 0.0 1.0 0.0 4.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 False False False False 0 False False False 0 False False False False 0.698935 True 23.325052
4 48.0 110.599998 1.0 37.000000 79.0 16.0 97.0 84.0 62.0 61.0 12.1 7.6 36.0 73.0 1.3 84.0 130.0 102.0 3.3 1.4 8.4 3.2 1.3 13.600000 32.400002 20.0 16.0 15.0 1.0 0.0 1.0 0.0 3.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 False False False False 1 False False False 0 False False False False 0.721635 True 23.378590

KDTree ¶

Deze methode zoekt in de dataset, aan de hand van een KDTree algoritme, een bestaande patiënt waarvoor het model een andere voorspelling zou doen. De nieuwe patiënt wordt zo gekozen dat die zo min mogelijk van de huidige patiënt afwijkt.

de nieuwe kolom 'dst' geeft de euclidische afstand tussen de counterfactual en het origineel aan, het is dus vergelijkbaar met 'fitness' in de vorige.

In [ ]:
data.drop(['sad'], axis=1).iloc[[1]]
Out[ ]:
age weight gender temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm aki stroke AISAN BLACK HISPANIC OTHER WHITE unknown CCU CVICU MICU MICU/SICU NICU SICU TSICU
1 56.0 119.300003 0 36.720001 82.0 22.0 90.0 75.0 56.0 61.0 13.0 7.2 36.0 70.0 2.7 83.0 128.0 103.0 4.0 2.2 7.0 4.5 2.1 22.4 38.400002 15.0 14.0 15.0 0.0 0.0 1.0 0.0 8 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0
In [ ]:
cf_kdtree = KDTreeCounterFactual(data.drop(['sad'], axis=1), model)
cf_kdtree.generate(data.drop(['sad'], axis=1).iloc[[1]], 5)
Out[ ]:
age weight gender temperature heart_rate resp_rate spo2 sbp dbp mbp wbc hemoglobin platelet bun cr glu Na Cl K Mg Ca P inr pt ptt bicarbonate aniongap gcs vent crrt vaso seda sofa_score ami ckd copd hyperte dm aki stroke AISAN BLACK HISPANIC OTHER WHITE unknown CCU CVICU MICU MICU/SICU NICU SICU TSICU reg pred dst
2912 62.0 103.000000 0 33.200001 78.0 16.0 97.0 89.0 51.0 58.0 10.9 8.1 55.0 101.0 8.6 70.0 138.0 105.0 4.9 2.9 7.9 6.6 2.5 27.200001 45.900002 17.0 21.0 4.0 0.0 0.0 1.0 0.0 8 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0 0.870129 True 50.399212
6348 69.0 104.000000 1 36.500000 93.0 31.0 93.0 71.0 43.0 50.0 0.9 11.5 37.0 38.0 3.0 101.0 136.0 103.0 6.1 2.3 8.9 9.7 1.6 18.100000 37.400002 15.0 28.0 12.0 1.0 1.0 1.0 1.0 3 1.0 1.0 0.0 0.0 1.0 1.0 0.0 0 0 0 0 0 1 0 0 0 0 0 0 1 0.841014 True 52.857719
12476 54.0 143.500000 0 36.610001 82.0 8.0 98.0 83.0 42.0 49.0 7.1 9.5 36.0 68.0 2.3 118.0 135.0 101.0 5.2 2.3 8.6 6.0 3.1 34.799999 47.200001 26.0 13.0 12.0 0.0 0.0 1.0 0.0 4 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0 0.681989 True 54.418489
6652 62.0 94.500000 1 36.560001 92.0 26.0 99.0 89.0 43.0 55.0 6.7 7.2 67.0 43.0 2.7 68.0 133.0 107.0 4.6 1.9 7.1 4.2 2.2 22.600000 38.599998 17.0 14.0 15.0 1.0 0.0 1.0 1.0 6 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0 0 0 0 1 0 0 0 1 0 0 0 0 0.650931 True 57.073599
792 75.0 105.900002 0 36.560001 83.0 20.0 100.0 103.0 66.0 77.0 1.1 9.3 61.0 76.0 2.4 108.0 139.0 101.0 4.4 2.6 7.7 4.5 1.4 15.500000 21.400000 23.0 15.0 15.0 1.0 0.0 0.0 0.0 5 0.0 0.0 0.0 0.0 0.0 1.0 1.0 0 0 0 0 1 0 1 0 0 0 0 0 0 0.645523 True 61.315624

SHAP ¶

Force plot ¶

Deze plot laat zien in hoe verre bepaalde features hebben bijgedragen aan de voorspelling. Een blauwe feature draagt negatief bij aan de voorspelling en een rode positief. Een langer balkje indiceert dat een feature meer bijgedragen heeft. 'f(x)' is de gemiddelde SHAP-waarde van deze voorspelling, en 'base value' is de verwachtingswaarde van alle SHAP-waarden in de set.

In [ ]:
row = 0  # `row` is de index van de huidige patient
shap.force_plot(explainer.expected_value, shap_values[row, :], data.loc[:, ~data.columns.isin(["sad"])].iloc[row, :], plot_cmap=["#FDE725", "#440154"])
Out[ ]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Waterfall ¶

De waterfall-plot is eigenlijk hetzelfde als de vorige maar dan staat elke feature los, en op volgorde van importance. Dit kan wat overzichtelijker zijn.

In [ ]:
# De SHAP-waarden in deze wijken iets af van de andere visualisaties, dit is omdat de rest is gebaseerd op `TreeExplainer` en deze `Explainer` van de SHAP library. 
# Deze plot kan het best achterwege gelaten worden omdat:
# - Kleuren kunnen niet veranderd worden
# - De plot kan niet op een TreeExplainer gemaakt worden
# - De plot komt toch al erg overeen met de decision plots
# de eerste twee punten kunnen overkomen worden door buiten de library te werken, maar dit lijkt de moeite niet waard

shap.plots.waterfall(shap_waterfall[0], max_display=14)
No description has been provided for this image

Decision plots ¶

Deze is vergelijkbaar met de waterfall-plot. Als we van beneden naar boven de lijn volgen zien we hoe iedere feature de uitkomst beinvloedt.

De decision plot wijkt wel af van de waterfall-plot in dat we op de x-as niet de cumulatieve SHAP waarde hebben, maar de uitkomst van het model (boven 0.5 is de voorspelling SAD, en daaronder NON-SAD). De grijze lijn is het gewogen gemiddelde van alle voorspellingen op de set.

In [ ]:
row = 0
shap.decision_plot(
    explainer.expected_value,
    shap_values[row, :],
    data.loc[:, ~data.columns.isin(["sad"])].iloc[0, :],
    link="logit",
    highlight=0,
    plot_color=plt.get_cmap("viridis")
)
No description has been provided for this image

Hier is de decision plot uitgebreid om de patiënt te vergelijken met andere patiënten naar keuze, in dit geval de eerste 5 patiënten uit de dataset

In [ ]:
row_current = 0
rows = [0,1,2,3,4]  # indices van de patienten waarmee we de huidige patient willen vergelijken

shap.decision_plot(
    explainer.expected_value, 
    shap_values[rows, :], 
    data.loc[:, ~data.columns.isin(["sad"])].iloc[0, :],
    link="logit", 
    highlight=row_current,
    plot_color=plt.get_cmap("viridis")
)
No description has been provided for this image

We kunnen de interactieve plot ook gebruiken voor locale uitleg; hieronder de interactieve plot voor de zelfde 5 patiënten als de vorige plot.

In [ ]:
shap.force_plot(
    explainer.expected_value, shap_values[[0,1,2,3,4], :], plot_cmap=["#FDE725", "#440154"]
)
Out[ ]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.